iT邦幫忙

2023 iThome 鐵人賽

DAY 26
0
AI & Data

30天把AI知識傳授給女友系列 第 26

Day26 建立 Pyorch 的自訂資料集和 DataLoader

  • 分享至 

  • xImage
  •  

今天介紹的內容與 Day11Day12 很像,我們需要建立資料集還有 DataLoader,首先我們先引入需要用到的套件,並且定義資料處理的流程:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from matplotlib import pyplot as plt

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([
    transforms.CenterCrop(60),
    transforms.ToTensor(),
    normalize
])

昨天我們有透過 csv 檔取出所有圖片路徑,這個函數是用來將圖片讀近來

def default_loader(path):
    img_pil =  Image.open(path)
    img_tensor = transform(img_pil)
    return img_tensor

Day11,12 由於我們用資料夾分好,所以可以用ImageFolder,我們需要繼承 Dataset,其中有三個函數是我們需要定義的,分別為

  • __init__ :初始化的函數,其中我們會傳入所有圖片路徑、對應的標籤和讀取圖片的函數。

  • __getitem__:給 index 的時候,返回一個圖片的 tensor 和 label 的 tensor,使用了loader方法,經過前處理,從圖像變成 tensor。

  • __len__:返回資料集的長度。

程式碼如下:

class CustomDataset(Dataset):
    def __init__(self, image_path, label, loader=default_loader):
        # 定義好 image 的路徑
        self.images = image_path
        self.label = label
        self.loader = loader

    def __getitem__(self, index):
        fn = self.images[index]
        img = self.loader(fn)
        label = self.label[index]
        return img, label

    def __len__(self):
        return len(self.images)

此處我們用 CustomDataset 實例化成 train_dataset, test_dataset,然後再建立 DataLoader

train_dataset  = CustomDataset(image_path=image_path_train, label=category_label_train)
train_dataloader  = DataLoader(train_dataset, batch_size=4,shuffle=True)
test_dataset  = CustomDataset(image_path=image_path_train, label=category_label_train)
test_loader = DataLoader(test_dataset, batch_size=4,shuffle=True)

此處我們將圖片和對應的標籤顯示出來:

# 顯示圖片和對應的標籤
train_features, train_labels = next(iter(train_dataloader))
print(f"特徵矩陣大小: {train_features.size()}")
print(f"類別數量: {train_labels.size()}")
img = train_features[0]
label = train_labels[0]
new_dic = {v : k for k, v in dic.items()}
print(new_dic[label.item()])
plt.imshow(img.permute(1, 2, 0))
plt.show()
print(f"Label: {label}")

可以看出來服飾對應的圖片是正確的:

https://ithelp.ithome.com.tw/upload/images/20231001/20153503znvUOIvzzN.png

結語

建立好資料及和DataLoader 後我們就可以開始建立模型和訓練了。


上一篇
Day25 建立時尚商品的資料集
下一篇
Day27 寫程式遇到解不掉的BUG就明天再說吧~
系列文
30天把AI知識傳授給女友30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言